{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "# torch.autograd.set_detect_anomaly(True)\n",
    "import numpy as np\n",
    "import yaml\n",
    "from functools import partial\n",
    "from time import gmtime, strftime, time\n",
    "from sklearn.metrics import classification_report\n",
    "# Hyper-parameters (batch_size, learning_rate, num_epochs) come from a YAML file that is\n",
    "# resolved relative to the current working directory.\n",
    "with open('resAE_parameters.yaml') as params_file:  # context manager closes the handle\n",
    "    params = yaml.safe_load(params_file)\n",
    "\n",
    "# Seed the torch and numpy RNGs and ask cuDNN for deterministic behaviour so runs are repeatable.\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "torch.backends.cudnn.deterministic = True  # Test this with False\n",
    "torch.backends.cudnn.benchmark = False\n",
    "\n",
    "# Prefer the GPU; fall back to the CPU so tensors created with `device` still work without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "batch_size = params['batch_size']"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Basic block\n",
    "\n",
    "class Basicblock(nn.Module):\n",
    "    expansion = 1\n",
    "    def __init__(self, input_planes, planes, stride=1, dim_change=None):\n",
    "        super(Basicblock,self).__init__()\n",
    "        # Declare convolutional layers with batch norms\n",
    "        self.conv1 = nn.Conv2d(input_planes, planes, stride=stride, kernel_size=3, padding=1)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, stride=1, kernel_size=3, padding=1)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.relu = nn.ReLU(True)\n",
    "        self.dim_change = dim_change\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # Save the residual\n",
    "        res = x\n",
    "        output = self.conv1(x)\n",
    "        output = self.bn1(output)\n",
    "        output = self.relu(output)\n",
    "        output = self.conv2(output)\n",
    "        output = self.bn2(output)\n",
    "        \n",
    "        if self.dim_change is not None:\n",
    "            # print(\"res before : \", res.size())\n",
    "            # print(self.dim_change)\n",
    "            res = self.dim_change(res)\n",
    "            # print(\"res after : \", res.size())\n",
    "            # print(\"output size : \", output.size())\n",
    "        output += res\n",
    "        output = self.relu(output)\n",
    "        \n",
    "        return output\n",
    "    "
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Res Encoder\n",
    "\n",
    "class ResEncoder(nn.Module):\n",
    "    def __init__(self, block, num_layers, classes=2):\n",
    "        super(ResEncoder, self).__init__()\n",
    "        self.input_planes = 8\n",
    "        # First layer is same\n",
    "        self.conv1 = nn.Conv2d(in_channels=3,out_channels=8,kernel_size=3,padding=1, stride=1)\n",
    "        self.relu = nn.ReLU(True)\n",
    "        \n",
    "        # Here comes the blocks\n",
    "        self.layer1 = self._layer(block, 16, num_layers[0], stride=2)\n",
    "        self.layer2 = self._layer(block, 32, num_layers[1], stride=2)\n",
    "        self.layer3 = self._layer(block, 64, num_layers[2], stride=2)\n",
    "        self.last_conv1 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,stride=2,padding=1)\n",
    "        self.last_bn1 =nn.BatchNorm2d(128)\n",
    "        self.last_leaky_relu = nn.LeakyReLU(0.0000001)\n",
    "        \n",
    "    def _layer(self, block, planes, num_layers, stride=1):\n",
    "        dim_change = None\n",
    "        if stride != 2 or planes != self.input_planes*block.expansion:\n",
    "            dim_change = nn.Sequential(\n",
    "                nn.Conv2d(self.input_planes, planes*block.expansion, kernel_size=1, stride=stride),\n",
    "                nn.BatchNorm2d(planes*block.expansion))\n",
    "            net_layers = []\n",
    "            net_layers.append(block(self.input_planes, planes, stride=stride, dim_change=dim_change))\n",
    "            self.input_planes = planes * block.expansion\n",
    "            for i in range(1, num_layers):\n",
    "                net_layers.append(block(self.input_planes, planes))\n",
    "                self.input_planes = planes * block.expansion\n",
    "            \n",
    "            return nn.Sequential(*net_layers)\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu(x)\n",
    "        \n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        \n",
    "        x = self.layer3(x)\n",
    "        \n",
    "        x = self.last_conv1(x)\n",
    "        x = self.last_bn1(x)\n",
    "        x = self.last_leaky_relu(x)\n",
    "        # print(\"size after encoder:\", x.size())\n",
    "        return x\n",
    "\n",
    "# Build the encoder from Basicblock residual blocks, three blocks per stage.\n",
    "# NOTE(review): ResAutoencoder stores this module-level object, so every ResAutoencoder()\n",
    "# created later shares these encoder weights.\n",
    "encoder = ResEncoder(Basicblock, [3, 3, 3])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Res Decoder\n",
    "\n",
    "class ResDecoder(nn.Module):\n",
    "    def __init__(self, block, num_layers, classes=2):\n",
    "        super(ResDecoder, self).__init__()\n",
    "        self.input_planes = 128\n",
    "        # Here comes the blocks\n",
    "        self.upsample = nn.Upsample(scale_factor=2,mode='nearest')\n",
    "        self.layer1 = self._layer(block, 64, num_layers[0], stride=1)\n",
    "        self.layer2 = self._layer(block, 32, num_layers[1], stride=1)\n",
    "        self.layer3 = self._layer(block, 16, num_layers[2], stride=1)\n",
    "        self.layer4 = self._layer(block, 8, num_layers[2], stride=1)\n",
    "        self.last_conv1 = nn.Conv2d(in_channels=8,out_channels=3,kernel_size=3,stride=1,padding=1)\n",
    "        self.tanh_10_2 =nn.Tanh()\n",
    "        \n",
    "        \n",
    "    def _layer(self, block, planes, num_layers, stride=2):\n",
    "        dim_change = None\n",
    "        if stride != 2 or planes != self.input_planes*block.expansion:\n",
    "            # print(\"self.input_planes:\", self.input_planes)\n",
    "            # print(\"planes*block.expansion:\", planes*block.expansion)\n",
    "            dim_change = nn.Sequential(\n",
    "                nn.Conv2d(self.input_planes, planes*block.expansion, kernel_size=1, stride=stride),\n",
    "                nn.BatchNorm2d(planes*block.expansion))\n",
    "            net_layers = []\n",
    "            net_layers.append(block(self.input_planes, planes, stride=stride, dim_change=dim_change))\n",
    "            self.input_planes = planes * block.expansion\n",
    "            for i in range(1, num_layers):\n",
    "                net_layers.append(block(self.input_planes, planes))\n",
    "                self.input_planes = planes * block.expansion\n",
    "            \n",
    "            return nn.Sequential(*net_layers)\n",
    "    def forward(self, x):\n",
    "        # print(\"Size before upsample :\", x.size())\n",
    "        x = self.upsample(x)\n",
    "        # print(\"Size after upsample :\", x.size())\n",
    "        x = self.layer1(x)\n",
    "        x = self.upsample(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.upsample(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.upsample(x)\n",
    "        x = self.layer4(x)\n",
    "        x = self.last_conv1(x)\n",
    "        x = self.tanh_10_2(x)\n",
    "        return x\n",
    "\n",
    "# Build the decoder from Basicblock residual blocks. Three counts cover its four stages\n",
    "# because ResDecoder uses the last entry for both layer3 and layer4.\n",
    "decoder = ResDecoder(Basicblock, [3, 3, 3])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# ResAutoencoder\n",
    "class ResAutoencoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ResAutoencoder, self).__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        # hook for the gradients of the activations\n",
    "        # self.gradients = None\n",
    "        # self.feature_conv = self.encoder.last_leaky_relu\n",
    "        \n",
    "    def forward(self, x, label):        \n",
    "        x = self.encoder(x)\n",
    "        act = x.clone()\n",
    "        dep = x.clone()\n",
    "        \n",
    "        # Selection block setting zero values based on label\n",
    "        # [:64] -> fake data latent space\n",
    "        # [64:] -> real data latent space\n",
    "        # 0->fake, 1->real\n",
    "        A = nn.Parameter(torch.zeros(64,15,15))\n",
    "        \n",
    "        for i in range(len(label)):\n",
    "            # real\n",
    "            if label[i].item():\n",
    "                # setting fake latent space into zero\n",
    "                dep[i, :64] = A\n",
    "            else:\n",
    "                dep[i, 64:] = A\n",
    "                \n",
    "        x = self.decoder(x)\n",
    "        \n",
    "        return x, act\n",
    "    \n",
    "    # Disabled helpers for capturing activations and their gradients (kept for reference).\n",
    "    # NOTE(review): get_activations reads `self.features_conv`, while the disabled line in\n",
    "    # __init__ names the attribute `self.feature_conv` - align the two before re-enabling.\n",
    "    # # hook for the gradients of the activation\n",
    "    # def activations_hook(self, grad):\n",
    "    #     self.gradients = grad\n",
    "    #     \n",
    "    # # method for the gradient extraction\n",
    "    # def get_activations_gradient(self):\n",
    "    #     return self.gradients\n",
    "    # \n",
    "    # # method for the activation extraction\n",
    "    # def get_activations(self, x):\n",
    "    #     return self.features_conv(x)\n",
    "    \n",
    "# Instance created right after the class definition; the following cells build their own `model`.\n",
    "resautoencoder = ResAutoencoder()\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Build the network and, when a checkpoint path is given, start from its weights.\n",
    "model = ResAutoencoder()\n",
    "model_path = ''  # leave empty to train from the freshly initialised weights\n",
    "if model_path:\n",
    "    # load_state_dict() returns a report of missing/unexpected keys, not the module, so its\n",
    "    # result must not be assigned back to `model`.\n",
    "    model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Training configuration read from the YAML parameters.\n",
    "num_epochs = params['num_epochs']\n",
    "learning_rate = params['learning_rate']\n",
    "\n",
    "# Move the network to the training device before the optimizer collects its parameters.\n",
    "model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-7)\n",
    "criterion1 = nn.L1Loss()  # reconstruction loss"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Dataset generation\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# Raw strings keep the Windows backslashes literal (\\l and \\d are not valid escape sequences).\n",
    "TRAIN_DIR = r\"D:\\labwork\\local_deepfake\\dataset/train\"\n",
    "VAL_DIR = r\"D:\\labwork\\local_deepfake\\dataset/val\"\n",
    "IMAGE_SIZE = (240, 240)\n",
    "\n",
    "# With image augmentation\n",
    "train_transform = transforms.Compose([\n",
    "    transforms.Resize(IMAGE_SIZE),\n",
    "    transforms.ColorJitter(hue=.05, saturation=.05),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomRotation(20),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "# Validation images are only resized, so metrics are measured on un-augmented data.\n",
    "val_transform = transforms.Compose([\n",
    "    transforms.Resize(IMAGE_SIZE),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DIR, transform=train_transform)\n",
    "# Show which integer label each class folder received.\n",
    "print(\"Class labels : \", train_dataset.class_to_idx)\n",
    "validation_dataset = torchvision.datasets.ImageFolder(root=VAL_DIR, transform=val_transform)\n",
    "\n",
    "train_dataloader = DataLoader(dataset=train_dataset,\n",
    "                              batch_size=batch_size,\n",
    "                              shuffle=True,\n",
    "                              num_workers=4)\n",
    "# No shuffling for evaluation: the aggregated metrics do not depend on batch order.\n",
    "validation_dataloader = DataLoader(dataset=validation_dataset,\n",
    "                                   batch_size=batch_size,\n",
    "                                   shuffle=False,\n",
    "                                   num_workers=4)\n",
    "\n",
    "print(\"train_dataset %d\" % len(train_dataset))\n",
    "print(\"validation set %d\" % len( validation_dataset))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Activation Loss\n",
    "def act_loss_func(outputs, labels):\n",
    "    batch_size = outputs.size()[0]\n",
    "    loss_list = torch.zeros([batch_size])\n",
    "    loss_list = loss_list.to(device)\n",
    "    # loss_list.cuda()\n",
    "    for i in range(batch_size):\n",
    "        #fake\n",
    "        total_loss =torch.zeros([1],dtype=torch.float32)\n",
    "        total_loss = total_loss.to(device)\n",
    "        # total_loss.cuda()\n",
    "        #real\n",
    "        total_loss_1 =torch.zeros([1],dtype=torch.float32)\n",
    "        total_loss_1 = total_loss_1.to(device)\n",
    "        # total_loss_1.cuda()\n",
    "        # real\n",
    "        if labels[i].item():\n",
    "            #fake\n",
    "            for latent_index in range(64):\n",
    "                temp= torch.sum(torch.abs(outputs[i,latent_index,:,:]))/225\n",
    "                total_loss = total_loss+temp\n",
    "            #real\n",
    "            for latent_index in range(64,128):\n",
    "                temp= torch.abs(1 -torch.sum(torch.abs(outputs[i,latent_index,:,:]))/14400)  # 15*15*64\n",
    "                total_loss_1 = total_loss_1+temp\n",
    "        #fake\n",
    "        else:\n",
    "            #fake\n",
    "            for latent_index in range(64):\n",
    "                temp= torch.abs(1- torch.sum(torch.abs(outputs[i,latent_index,:,:]))/14400)  # 15*15*64\n",
    "                total_loss = total_loss+temp\n",
    "            #real\n",
    "            for latent_index in range(64,128):\n",
    "                temp= torch.sum(torch.abs(outputs[i,latent_index,:,:]))/225\n",
    "                total_loss_1 = total_loss_1+temp\n",
    "        \n",
    "        loss_list[i]=total_loss+total_loss_1\n",
    "\n",
    "        \n",
    "    return torch.sum(loss_list)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "# test\n",
    "def act_loss_test(outputs):\n",
    "    batch_size = outputs.size()[0]\n",
    "    answer = torch.zeros([batch_size,2])\n",
    "    answer.cuda()\n",
    "    for i in range(batch_size):\n",
    "        fake = torch.zeros([1], dtype=torch.float32).to(device)\n",
    "        real = torch.zeros([1], dtype=torch.float32).to(device)\n",
    "        \n",
    "        # fake latent space\n",
    "        for latent_index in range(64):\n",
    "            # fake = fake + torch.sum(torch.abs(outputs[i, latent_index]))\n",
    "            fake = fake + torch.sum(torch.abs(outputs[i, latent_index]))/14400\n",
    "        # real latent space\n",
    "        for latent_index in range(64, 128):\n",
    "            # real = real + torch.sum(torch.abs(outputs[i, latent_index]))\n",
    "            real = real + torch.sum(torch.abs(outputs[i, latent_index]))/14400\n",
    "\n",
    "        answer[i][0] = fake.item() / (fake.item() + real.item())\n",
    "        answer[i][1] = real.item() / (fake.item() + real.item())\n",
    "           \n",
    "    return answer\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Base Dataset Training"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from torch.autograd import Variable\n",
    "import torchvision.utils as vutils\n",
    "from sklearn.metrics import classification_report\n",
    "import os\n",
    "\n",
    "# Class order used by the report: index 0 is fake and index 1 is real, matching the latent\n",
    "# split used by ResAutoencoder / act_loss_test and the target_names of the test cell.\n",
    "target_names = ['fake', 'real']\n",
    "print(\"Start training\")\n",
    "# Raw string keeps the Windows backslashes literal (\\l and \\s are not valid escape sequences).\n",
    "checkpoint_dir = os.path.join(r'D:\\labwork\\local_deep_transfer\\source\\saved_models/checkpoints_kaggle',\n",
    "                              strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime()))\n",
    "# Create the run directory up front so a bad path fails before the first epoch, not after it.\n",
    "os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    run_loss = 0.0\n",
    "    num_batches = 0\n",
    "    model.train()\n",
    "    print(\"epoch : \", epoch)\n",
    "    start = time()\n",
    "    for x, label in train_dataloader:\n",
    "        # The batch already has the shape the network expects; it only needs to change device.\n",
    "        x = x.to(device)\n",
    "        label = label.to(device)\n",
    "\n",
    "        output, act_data = model(x, label)\n",
    "        rec_loss = criterion1(output, x)            # L1 reconstruction error\n",
    "        act_loss = act_loss_func(act_data, label)   # latent activation loss\n",
    "        loss = act_loss + 0.1 * rec_loss\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        run_loss += loss.item()\n",
    "        num_batches += 1\n",
    "    end = time()\n",
    "    # Report the mean loss over the epoch rather than whatever the last batch produced.\n",
    "    print('epoch [{}/{}], loss:{:.4f}'\n",
    "          .format(epoch + 1, num_epochs, run_loss / max(num_batches, 1)))\n",
    "    print(\"Took \", end - start)\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    pred = []\n",
    "    labels = []\n",
    "    correct = 0\n",
    "    val_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for x, label in validation_dataloader:\n",
    "            x = x.to(device)\n",
    "            # forward() requires a label argument; a random placeholder is passed, presumably\n",
    "            # so that the ground truth cannot influence the prediction.\n",
    "            placeholder = torch.rand([label.shape[0]]).to(device)\n",
    "            output, act_data = model(x, placeholder)\n",
    "            outputs = act_loss_test(act_data)\n",
    "\n",
    "            rec_loss = criterion1(output, x)\n",
    "            act_loss = act_loss_func(act_data, label.to(device))\n",
    "            val_loss += (act_loss + 0.1 * rec_loss).item()\n",
    "\n",
    "            # Compare on the CPU, where the labels yielded by the dataloader live.\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            predicted = predicted.cpu()\n",
    "            pred += predicted.tolist()\n",
    "            labels += label.tolist()\n",
    "            correct += (predicted == label).sum().item()\n",
    "\n",
    "    print(\"Validation Loss is %f\" % val_loss)\n",
    "    accuracy = 100.0 * correct / max(len(validation_dataset), 1)  # percentage, as the message says\n",
    "    print('Validation Accuracy %f %%' % accuracy)\n",
    "    # labels=[0, 1] keeps the report aligned with target_names even if one class is absent.\n",
    "    print(classification_report(labels, pred, labels=[0, 1], target_names=target_names, digits=4))\n",
    "\n",
    "    # One checkpoint per epoch inside the time-stamped run directory.\n",
    "    checkpoint_path = os.path.join(checkpoint_dir, \"TAR_\" + str(epoch) + 'epoch_.pth')\n",
    "    torch.save(model.state_dict(), checkpoint_path)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Test "
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the model\n",
    "model = ResAutoencoder()\n",
    "# Raw string keeps the Windows backslashes literal; map_location places the weights on `device`\n",
    "# regardless of the device they were saved from.\n",
    "test_checkpoint = r\"D:\\labwork\\local_deep_transfer\\source\\saved_models\\checkpoints/resAE_face2face/resAE_face2face_small_leaky_11epoch_.pth\"\n",
    "model.load_state_dict(torch.load(test_checkpoint, map_location=device))\n",
    "print(model)\n",
    "model.to(device)\n",
    "model.eval()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Root folder of the test images; set this before running the cell.\n",
    "test_path = ''\n",
    "# Resize to 240x240 and convert to a tensor, without augmentation.\n",
    "test_transform = transforms.Compose([\n",
    "    transforms.Resize((240, 240)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "test = torchvision.datasets.ImageFolder(root=test_path, transform=test_transform)\n",
    "test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=True)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Index 0 is fake and index 1 is real, matching the columns returned by act_loss_test.\n",
    "target_names = ['fake','real']\n",
    "pred = []\n",
    "labels = []\n",
    "correct = 0\n",
    "test_loss = 0.0\n",
    "with torch.no_grad():\n",
    "    for x, label in test_dataloader:\n",
    "        # The batch already has the shape the network expects; it only needs to change device.\n",
    "        x = x.to(device)\n",
    "        # forward() requires a label argument; a random placeholder is passed, presumably\n",
    "        # so that the ground truth cannot influence the prediction.\n",
    "        placeholder = torch.rand([label.shape[0]]).to(device)\n",
    "        output, act_data = model(x, placeholder)\n",
    "        outputs = act_loss_test(act_data)\n",
    "\n",
    "        rec_loss = criterion1(output, x)\n",
    "        act_loss = act_loss_func(act_data, label.to(device))\n",
    "        test_loss += (act_loss + 0.1 * rec_loss).item()\n",
    "\n",
    "        # Compare on the CPU, where the labels yielded by the dataloader live.\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        predicted = predicted.cpu()\n",
    "        pred += predicted.tolist()\n",
    "        labels += label.tolist()\n",
    "        correct += (predicted == label).sum().item()\n",
    "\n",
    "print(\"test Loss is %f\" % test_loss)\n",
    "accuracy = 100.0 * correct / max(len(pred), 1)  # percentage, as the message says\n",
    "print('test Accuracy %f %%' % accuracy)\n",
    "# labels=[0, 1] keeps the report aligned with target_names even if one class is absent.\n",
    "print(classification_report(labels, pred, labels=[0, 1], target_names=target_names, digits=4))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}